# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Convert coco data to mindrecord format."""

import json
import os

import numpy as np
import pycocotools.coco as coco
import scipy.io as sio
from mindspore.mindrecord import FileWriter

from mindvision.engine.class_factory import ClassFactory, ModuleType


@ClassFactory.register(ModuleType.DATASET)
class Centerface2MindRecord:
    """
    Convert a CenterFace (WIDER FACE) dataset split to MindRecord files.

    Args:
        img_dir (str): Root directory that image file names are joined onto.
        annot_path (str): Annotation file. It is loaded as a COCO-format
            file for the 'train' split and as a MATLAB ``.mat`` file
            (containing 'event_list' and 'file_list') for the 'test' split.
        mindrecord_dir (str): Directory the MindRecord files are written to.
        split (str): Dataset split handled by ``__call__``: 'train' or 'test'.
        max_objs (int): Maximum number of annotations kept per training
            image. Default: 64.
    """

    def __init__(self, img_dir, annot_path, mindrecord_dir, split, max_objs=64):
        """Constructor for Centerface2MindRecord"""
        self.img_dir = img_dir
        self.annot_path = annot_path
        self.mindrecord_dir = mindrecord_dir
        self.split = split
        self.max_objs = max_objs

    def _prepare_output_dir(self):
        """Create the output directory (including parents) if it is missing."""
        os.makedirs(self.mindrecord_dir, exist_ok=True)

    def create_centerface_train_label(self):
        """
        Get image paths and annotations from the COCO-format annotation file.

        Images without any annotation are skipped. When an image has more
        than ``max_objs`` annotations, a random subset of ``max_objs``
        distinct annotations is kept.

        Returns:
            tuple: ``(images_path, image_anno_dict)`` where ``images_path``
            is a list of image paths and ``image_anno_dict`` maps each path
            to a list of rows, each row being the annotation's 'bbox' values
            followed by its 'keypoints' values.
        """
        print('==> getting centerface key point {} data.'.format(self.split))
        centercoco = coco.COCO(self.annot_path)
        images_path = []
        image_anno_dict = {}
        for img_id in centercoco.getImgIds():
            ann_ids = centercoco.getAnnIds(imgIds=[img_id])
            if not ann_ids:
                continue
            file_name = centercoco.loadImgs(ids=[img_id])[0]['file_name']
            img_path = os.path.join(self.img_dir, file_name)
            anns = centercoco.loadAnns(ids=ann_ids)
            if len(anns) > self.max_objs:
                # Sample without replacement so no annotation is duplicated
                # while others are dropped.
                anns = np.random.choice(anns, self.max_objs, replace=False)
            target = []
            for ann in anns:
                row = []
                row.extend(ann['bbox'])
                row.extend(ann['keypoints'])
                target.append(row)
            images_path.append(img_path)
            image_anno_dict[img_path] = target
        print('Loaded {} {} samples'.format(self.split, len(images_path)))
        return images_path, image_anno_dict

    def create_centerface_test_label(self):
        """
        Get image paths for evaluation from the ``.mat`` ground-truth file.

        Returns:
            tuple: ``(images_path, img_file_dict, img_name_dict, img_id_dict)``.
            ``images_path`` is the list of image paths; the three dicts map
            each path to its event directory, its image name (without the
            ``.jpg`` extension) and its 1-based running image id.
        """
        print('==> getting centerface key point {} data.'.format(self.split))
        ground_truth_mat = sio.loadmat(self.annot_path)
        event_list = ground_truth_mat['event_list']
        file_list = ground_truth_mat['file_list']
        img_id = 0
        images_path = []
        img_file_dict = {}
        img_name_dict = {}
        img_id_dict = {}
        for index, event in enumerate(event_list):
            file_list_item = file_list[index][0]
            im_dir = event[0][0]
            # Iterate the entries directly: wrapping them in enumerate() would
            # yield (index, entry) tuples and break the [0][0] unwrapping.
            for file_obj in file_list_item:
                im_name = file_obj[0][0]
                img_id += 1
                zip_name = '%s/%s.jpg' % (im_dir, im_name)
                img_path = os.path.join(self.img_dir, zip_name)
                images_path.append(img_path)
                img_file_dict[img_path] = im_dir
                img_name_dict[img_path] = im_name
                img_id_dict[img_path] = img_id
        print('Loaded {} {} samples'.format(self.split, len(images_path)))
        return images_path, img_file_dict, img_name_dict, img_id_dict

    def train_data_to_mindrecord_byte_image(self, prefix="centerface.mindrecord", file_num=8):
        """
        Create the training MindRecord files.

        Args:
            prefix (str): File name of the MindRecord output inside
                ``mindrecord_dir``. Default: "centerface.mindrecord".
            file_num (int): Number of files passed to ``FileWriter``.
                Default: 8.
        """
        self._prepare_output_dir()

        mindrecord_path = os.path.join(self.mindrecord_dir, prefix)
        writer = FileWriter(mindrecord_path, file_num)
        image_files, image_anno_dict = self.create_centerface_train_label()

        # Each annotation row holds 19 floats: bbox values then keypoints.
        centerface = {
            "image": {"type": "bytes"},
            "annotation": {"type": "float32", "shape": [-1, 19]},
        }
        writer.add_schema(centerface, "centerface")

        for image_name in image_files:
            with open(image_name, 'rb') as f:
                img = f.read()
            annos = np.array(image_anno_dict[image_name], dtype=np.float32)
            row = {"image": img, "annotation": annos}
            writer.write_raw_data([row])
        writer.commit()

    def test_data_to_mindrecord_byte_image(self, prefix="centerface.mindrecord", file_num=1):
        """
        Create the evaluation MindRecord file and its JSON index.

        Besides the MindRecord output, a list of
        ``{'id', 'im_dir', 'im_name'}`` entries (one per image) is written
        to ``mindrecord_eval.json`` in the current working directory.

        Args:
            prefix (str): File name of the MindRecord output inside
                ``mindrecord_dir``. Default: "centerface.mindrecord".
            file_num (int): Number of files passed to ``FileWriter``.
                Default: 1.
        """
        self._prepare_output_dir()

        mindrecord_path = os.path.join(self.mindrecord_dir, prefix)
        writer = FileWriter(mindrecord_path, file_num)
        image_files, img_file_dict, img_name_dict, img_id_dict = self.create_centerface_test_label()
        centerface_eval = {
            "image": {"type": "bytes"},
            "image_id": {"type": "int32"}
        }
        writer.add_schema(centerface_eval, "centerface_eval")

        eval_list = []
        for image_name in image_files:
            with open(image_name, 'rb') as f:
                img = f.read()
            img_id = img_id_dict[image_name]
            eval_list.append({
                'id': img_id,
                'im_dir': img_file_dict[image_name],
                'im_name': img_name_dict[image_name],
            })
            row = {"image": img, "image_id": img_id}
            writer.write_raw_data([row])
        writer.commit()

        with open('mindrecord_eval.json', 'w') as f:
            f.write(json.dumps(eval_list))

    def __call__(self, prefix='centerface.mindrecord'):
        """
        Write the MindRecord file(s) for the configured split.

        Nothing is written when the expected output already exists, and any
        split other than 'train' or 'test' is ignored.

        Args:
            prefix (str): File name of the MindRecord output inside
                ``mindrecord_dir``. Default: 'centerface.mindrecord'.
        """
        mindrecord_path = os.path.join(self.mindrecord_dir, prefix)
        if self.split == 'train':
            # The train existence check looks for the first shard, "<prefix>0".
            if os.path.exists(mindrecord_path + "0"):
                print("mindrecord train dataset is existed")
                return
            # Forward the prefix so the files written match the path checked.
            self.train_data_to_mindrecord_byte_image(prefix=prefix)
        elif self.split == 'test':
            if os.path.exists(mindrecord_path):
                print("mindrecord test dataset is existed")
                return
            self.test_data_to_mindrecord_byte_image(prefix=prefix)
